import argparse
import logging
from collections import defaultdict
from pathlib import Path

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
import matplotlib.pyplot as plt
import numpy as np

from utils.helpers import set_seed

# One wandb run for the whole script; hyperparameter values are attached to
# wandb.config inside the sweep loop at the bottom of the file.
wandb.init(project="offline-demo", )

set_seed(112)
# Console-only logging with timestamps; no file handler is configured here.
formatter = logging.Formatter('%(asctime)s %(levelname)s - %(funcName)s: %(message)s',
                              "%H:%M:%S")
logger = logging.getLogger(__name__)
logger.setLevel('INFO')

stream = logging.StreamHandler()
stream.setLevel('INFO')
stream.setFormatter(formatter)
logger.addHandler(stream)

# dsprites-derived dataset saved via torch.save; expected keys: 'imgs',
# 'latents_values'. NOTE(review): torch.load unpickles arbitrary objects —
# only load trusted files.
dataset_zip = torch.load('data/dsprites/dsprites_move.pt')
imgs = dataset_zip['imgs']
# Keep only latent factors 4 and 5 (presumably x/y position — TODO confirm
# against the dsprites latent ordering).
latents_values = dataset_zip['latents_values'][:, [4, 5]]


# corrupt_set = torch.cat([imgs[:32 * 16], ] + [imgs[32 * 16 + i * 32:
#                                                    32 * 16 + i * 32 + 16] for i in range(16)], dim=0)
#
# unseen_set = torch.cat([imgs[32 * 16 + i * 32 + 16:
#                              32 * 16 + i * 32 + 32] for i in range(16)], dim=0)
#
# # torch.cat([list(range(32 * 16 + i * 32 + 16, 32 * 16 + i * 32 + 32)) for i in range(16)], dim=0)
#
# diagonal_set = torch.cat([imgs[32 * i + i] for i in range(32)], dim=0)
#
# y_set = imgs[:32]


class VAE(nn.Module):
    """Variational autoencoder with grouped latents and a discriminator head.

    The encoder emits ``latent_dim * group`` units; ``forward`` sums each
    consecutive run of ``group`` units into one of the ``latent_dim`` values
    fed to the decoder. A small MLP discriminator scores flattened images
    (its input size is height * width, so it assumes a single channel).

    Parameters
    ----------
    img_size : tuple of ints
        Size of images. E.g. (1, 32, 32) or (3, 64, 64).
    encoder : callable
        Factory ``encoder(img_size, n_units)`` returning the encoder module.
    decoder : callable
        Factory ``decoder(img_size, latent_dim)`` returning the decoder module.
    latent_dim : int
        Dimension of the latent space seen by the decoder.
    group : int, optional
        How many encoder units are summed into each latent dimension.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim, group=1):
        super().__init__()

        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(
                    img_size))

        self.img_size = img_size
        self.latent_dim = latent_dim
        self.group = group
        self.num_pixels = img_size[1] * img_size[2]
        self.encoder = encoder(img_size, self.latent_dim * group)
        self.decoder = decoder(img_size, self.latent_dim)

        n_hidden = 512
        self.discriminator = nn.Sequential(
            nn.Linear(img_size[1] * img_size[2], n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1),
        )

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """Draw z ~ N(mean, diag(exp(logvar))) via the reparameterization trick.

        In eval mode no sampling happens and the mean is returned, making
        reconstructions deterministic.

        Parameters
        ----------
        mean : torch.Tensor
            Shape (batch_size, latent_dim).
        logvar : torch.Tensor
            Diagonal log variance, shape (batch_size, latent_dim).
        """
        if not self.training:
            # Reconstruction mode: use the distribution mean directly.
            return mean
        noise = torch.randn_like(mean)
        return mean + noise * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Encode, sample, group-sum, decode.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width).

        Returns
        -------
        tuple
            (reconstruction, (mean, logvar), ungrouped latent sample).
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        # grouped z: each decoder latent is the sum of `group` encoder units
        grouped = latent_sample.view(x.size(0), self.latent_dim, self.group).sum(2)
        return self.decoder(grouped), latent_dist, latent_sample

    def reset_parameters(self):
        """Re-initialize every submodule with the project's weight init."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """Return the latent distribution and a sample from it.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width).
        """
        dist = self.encoder(x)
        return dist, self.reparameterize(*dist)


def train(dl, model, loss_f, epochs):
    """Train ``model`` on batches from ``dl`` for ``epochs`` epochs.

    Parameters
    ----------
    dl : torch.utils.data.DataLoader
        Yields image batches reshapeable to (-1, 1, 64, 64).
    model : VAE
        Model to optimize; all of its parameters go to a single AdamW.
    loss_f : callable
        ``loss_f(img, recon, latent_dist, is_train, storer, latent_sample=...)``
        returning a scalar loss; it may append per-term values to ``storer``.
    epochs : int
        Number of passes over ``dl``.

    Side effects: logs averaged metrics to wandb and the module logger each
    epoch, and every 50 epochs writes a latent-trajectory figure into the
    module-level ``path`` directory (NOTE(review): ``path`` is still a
    global — confirm it is set before calling).
    """
    opt = optim.AdamW(model.parameters(), 1e-3)
    for e in range(epochs):
        storer = defaultdict(list)
        loss_set = []
        latent_dist = None  # guard: stays None if dl yields no batches
        for img in dl:
            img = img.view(-1, 1, 64, 64)
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        wandb.log(storer, sync=False, step=e)
        logger.info(';'.join([f'{k}:{v}' for k, v in storer.items()]))

        if e % 50 == 0:
            # Was plot_trajectory(dataset, vae): use the function's own
            # arguments, not the module globals, so train() works for any
            # model/loader pair.
            fig = plot_trajectory(dl.dataset, model)
            fig.savefig(path.joinpath(f'trajectory_{e // 50}.png'))
            plt.close(fig)  # don't accumulate open figures across epochs
        if latent_dist is not None:
            with torch.no_grad():
                mu, log_var = latent_dist
                var = log_var.exp().cpu()
                mu = mu.cpu()
                # Was the global `dim`; the model knows its own latent size.
                for d in range(model.latent_dim):
                    wandb.log({f"mu_{d}": wandb.Histogram(mu[:, d])}, step=e, sync=False)
                    wandb.log({f"var_{d}": wandb.Histogram(var[:, d])}, step=e, sync=False)


def train_vaegan(dl, model, loss_f, epochs):
    """Train ``model`` as a VAE-GAN with three optimization phases per batch.

    Phase 1 updates the encoder (decoder frozen) on reconstruction + KL;
    phase 2 updates the decoder (discriminator frozen) on reconstruction
    minus the mean discriminator score; phase 3 updates the discriminator
    to score real images above reconstructions and pure samples.

    Parameters
    ----------
    dl : iterable
        Yields image batches reshapeable to (-1, 1, 64, 64).
    model : VAE
        Must expose ``encoder``, ``decoder`` and ``discriminator`` submodules.
    loss_f : callable
        Unused in this function; kept for signature parity with ``train``.
    epochs : int
        Number of passes over ``dl``.

    NOTE(review): one AdamW over all parameters serves all three phases;
    ``requires_grad_`` toggling decides which parameters receive gradients
    in each phase.
    """
    opt = optim.AdamW(model.parameters(), 1e-3)
    for e in range(epochs):
        storer = defaultdict(list)
        for img in dl:
            model.decoder.requires_grad_(False)
            # reconstruct, opt encoder only
            model.encoder.requires_grad_(True)
            batch_size = img.size(0)
            img = img.view(-1, 1, 64, 64)
            recon, latent_dist, latent_codes = model(img)  # treat mu as the code, ignore log_var
            kl = _kl_normal_loss(*latent_dist)
            recon_loss = F.binary_cross_entropy(recon, img, reduction='sum')
            vae_loss = recon_loss + kl

            opt.zero_grad()
            vae_loss.backward()
            opt.step()
            storer['recon_loss'].append(recon_loss.item())
            storer['kl'].append(kl.item())

            model.discriminator.requires_grad_(False)
            # gen, opt decoder only
            model.decoder.requires_grad_(True)

            # .data detaches the codes so no gradient reaches the encoder here
            recon = model.decoder(latent_codes.data)
            dis_vae = model.discriminator(recon.view(batch_size, -1))
            gen_codes = (torch.randn_like(latent_codes))
            gen_imgs = model.decoder(gen_codes)
            dis_gan = model.discriminator(gen_imgs.view(batch_size, -1))
            disc_loss = (dis_vae + dis_gan).mean()

            recon_loss = F.binary_cross_entropy(recon, img, reduction='sum')
            # decoder is rewarded for raising the discriminator's scores
            loss = recon_loss - disc_loss

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['disc_loss'].append(disc_loss.item())

            model.decoder.requires_grad_(False)
            # discriminator, opt discriminator only
            model.discriminator.requires_grad_(True)
            recon = recon.view(-1, 64 * 64).data
            # dis x_vae
            dis_vae = model.discriminator(recon)
            # dis x_real
            dis_real = model.discriminator(img.view(-1, 64 * 64))
            # GAN part, x_gan
            gen_codes = (torch.randn_like(latent_codes))
            gen_imgs = model.decoder(gen_codes).data
            dis_gan = model.discriminator(gen_imgs.view(batch_size, -1))

            # real scored high, fakes scored low; the +3 is a constant offset
            # that does not affect gradients — presumably it keeps the logged
            # value positive (TODO confirm intent)
            loss = (-dis_real + dis_gan + dis_vae).mean() + 3

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['disc_loss1'].append(loss.item())


        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['epoch'] = e
        logger.info(';'.join([f'{k}:{v}' for k, v in storer.items()]))


def traversal_latents(base_latent, traversal_vector, dim):
    """Tile ``base_latent`` and sweep dimension ``dim`` over ``traversal_vector``.

    Returns a new (len(traversal_vector), latent_dim) tensor whose row i is
    ``base_latent`` with component ``dim`` replaced by traversal_vector[i];
    ``base_latent`` itself is left untouched.
    """
    n_steps = len(traversal_vector)
    swept = base_latent.repeat(n_steps, 1)
    swept[:, dim] = traversal_vector
    return swept


def rand_sample(ds, sz=1):
    """Return ``sz`` elements of ``ds`` drawn uniformly without replacement."""
    picked = torch.randperm(len(ds))[:sz]
    return ds[picked]


def plot_bar(axes, images):
    """Render each image on its matching axis in grayscale with axes hidden."""
    for axis, image in zip(axes, images):
        axis.imshow(image.numpy(), cmap='gray')
        axis.axis('off')


def plot_reconstruct(dataset, size, model):
    """Plot an r x c grid of (original, reconstruction) column pairs.

    Parameters
    ----------
    dataset : torch.Tensor
        Pool of images to sample from; r*c are drawn at random.
    size : tuple of ints
        (r, c): r rows, c original/reconstruction column pairs.
    model : VAE
        Model used to reconstruct the sampled images.

    Returns
    -------
    matplotlib.figure.Figure
    """
    r, c = size
    with torch.no_grad():
        img = rand_sample(dataset, r * c)
        img = img.view(-1, 1, 64, 64)
        recon_batch, latent_dist, latent_sample = model(img)

        fig, axes = plt.subplots(r, c * 2, figsize=(c * 2 * 4, r * 4))
        plt.tight_layout()
        for i in range(c):
            # Columns come in (origin, recon) pairs at 2*i and 2*i+1.
            # The original indexed with i*c, which only happens to be
            # correct for c == 2 and goes out of bounds for c >= 3.
            plot_bar(axes[:, 2 * i], img[r * i:r * i + r, 0])
            axes[0, 2 * i].set_title('origin')
            plot_bar(axes[:, 2 * i + 1], recon_batch[r * i:r * i + r, 0])
            axes[0, 2 * i + 1].set_title('recon')

        return fig


def plot_rand_reconstruct(size, model):
    """Decode r*c latent codes drawn from N(0, I) and plot them as a grid.

    Returns the matplotlib figure; images are assumed to decode to 64x64.
    """
    rows, cols = size
    with torch.no_grad():
        codes = torch.randn(rows * cols, model.latent_dim)
        decoded = model.decoder(codes)

        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        plt.tight_layout()
        for image, axis in zip(decoded, axes.flatten()):
            axis.imshow(image.reshape(64, 64), cmap='gray')
            axis.axis('off')
        return fig

def plt_sample_traversal(sample, model, traversal_len=5, dim_len=4):
    """Plot latent traversals of ``sample``: one row per latent dimension.

    Each row sweeps one latent dimension over [-3, 3] (sigma fixed at 1;
    a per-dimension sigma from log_var was tried and abandoned) while the
    other dimensions stay at the encoder mean of ``sample``.
    """
    with torch.no_grad():
        mu, log_var = model.encoder(sample)
        fig, axes = plt.subplots(dim_len, traversal_len,
                                 figsize=(traversal_len * 4, dim_len * 4))
        plt.tight_layout()

        # Same sweep for every dimension, so build it once.
        sweep = torch.linspace(-3, 3, traversal_len)
        for dim in range(dim_len):
            grid = traversal_latents(mu, sweep, dim)
            decoded = model.decoder(grid)
            plot_bar(axes[dim], decoded[:, 0])

        return fig


# 'RdYlBu' colormap for the trajectory scatter plots below.
# plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
# plt.get_cmap returns the same colormap and keeps working on new releases.
cm = plt.get_cmap('RdYlBu')


def plot_trajectory(sample, model):
    """Scatter the first two latent means for one row and one column of images.

    Assumes ``sample`` is a 32x32 grid of images flattened row-major: the
    first 32 entries form one row and every 32nd entry forms one column
    (c=range(32) requires exactly 32 points per slice — TODO confirm
    against the dsprites_move layout).
    """
    fig = plt.figure()
    with torch.no_grad():
        mu, log_var = model.encoder(sample.view(-1, 1, 64, 64))
        mu = mu.cpu()
        row_pts = plt.scatter(mu[:32, 0].data, mu[:32, 1].data, c=range(32))
        plt.colorbar(row_pts)
        col_pts = plt.scatter(mu[::32, 0].data, mu[::32, 1].data, c=range(32), cmap=cm)
        plt.colorbar(col_pts)
    return fig


exp_dir = Path('results/')

# Masked variant of the dataset (every other image row zeroed out).
# NOTE(review): currently unused — the loop below only iterates over imgs.
splines = imgs * torch.cat([torch.ones(1, 64), torch.zeros(1, 64)]).repeat(32, 1)

epochs = 250
dim = 4  # latent dimensionality of the VAE
for ds in [imgs]:

    dataset = ds.cuda()
    dl = DataLoader(dataset, num_workers=0, batch_size=64, shuffle=True)

    # Sweep over the beta-H regularization weight.
    for h in [16]:
        wandb.config.latent_dim = dim
        wandb.config.betaH_B = h
        wandb.config.epochs = epochs
        wandb.config.batch_size = 64
        if ds is imgs:
            path = exp_dir.joinpath(f'normal_{h}')
        else:
            path = exp_dir.joinpath(f'splines_{h}')
        # parents=True also creates 'results/' when it does not exist yet;
        # plain mkdir(exist_ok=True) raises FileNotFoundError in that case.
        path.mkdir(parents=True, exist_ok=True)

        vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
        vae.cuda()
        loss_f = get_loss_f('Dis', betaH_B=h,
                            record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))

        train(dl, vae, loss_f, epochs)

        torch.save(vae.state_dict(), path.joinpath('model.pt'))
        vae.cpu()  # plots below run on CPU

        fig = plot_reconstruct(imgs, (3, 2), vae)
        fig.savefig(path / 'reconstruction.png')
        fig = plt_sample_traversal(imgs[:1].view(1, 1, 64, 64), vae, 7, dim)
        fig.savefig(path / 'traversal.png')

wandb.join()  # deprecated alias of wandb.finish(); flushes and closes the run
# import pandas as pd
#
# lines = open(path.joinpath('exp.log')).readlines()
# df = pd.DataFrame([dict([item.split(':') for item in line[18:-1].split(';')]) for line in lines])
# linear_traversal = torch.linspace(-1, 1 , 7)
# traversals = traversal_latents(mu[:1], linear_traversal, 2)
# recon_batch = model.decoder(traversals)
# fig, axes = plt.subplots(1, 7,figsize=(4*7,4))
# plot_bar(axes,recon_batch[:, 0].detach());plt.show()
